from pathlib import Path
import seaborn as sns
from scipy.stats import pearsonr
import itertools
from scipy.optimize import curve_fit
from scipy.special import erf
import efel
import matplotlib.pyplot as plt
from datareuse import Reuse
import numpy as np
import sys
import pandas as pd
from bluepyparallel import evaluate
from matplotlib.backends.backend_pdf import PdfPages
import extra_features

import pyabf


def plot_corr(corr_df):
    """Plot a correlation matrix as a heatmap on a new 8x8 figure.

    The colour scale is symmetric around 0 and saturates at +/-0.8, so
    negative correlations appear in cool colours and positive ones in warm
    colours of the ``coolwarm`` map.

    Args:
        corr_df (pandas.DataFrame): square matrix of Pearson coefficients.

    Returns:
        matplotlib.axes.Axes: the axes holding the heatmap.
    """
    plt.figure(figsize=(8, 8))
    ax = plt.gca()

    sns.heatmap(
        data=corr_df,
        ax=ax,
        # vmin must be below vmax; the limits were previously passed swapped
        vmin=-0.8,
        vmax=0.8,
        cmap="coolwarm",
        linewidths=0.5,
        linecolor="k",
        cbar_kws={"label": "pearson", "shrink": 0.3},
        xticklabels=True,
        yticklabels=True,
        square=True,
    )
    plt.tight_layout()
    return ax


def get_2d_correlations(df):
    """Compute pairwise Pearson correlations between the columns of ``df``.

    A value of 0 is treated as "feature not available", so for each pair of
    columns only the rows where both values are non-zero are correlated.

    Args:
        df (pandas.DataFrame): numeric table, one column per feature.

    Returns:
        pandas.DataFrame: symmetric matrix indexed by ``df.columns`` on both
        axes. The diagonal is set to 0. Pairs with fewer than two rows where
        both values are non-zero get NaN, because Pearson's r is undefined
        for them.
    """
    columns = df.columns
    corr_df = pd.DataFrame(index=columns, columns=columns, dtype=float)
    for x, y in itertools.combinations(columns, 2):
        dx, dy = df[x].to_numpy(), df[y].to_numpy()
        both = (dx != 0) & (dy != 0)
        if np.count_nonzero(both) < 2:
            # pearsonr raises ValueError on fewer than two samples
            corr = np.nan
        else:
            corr = pearsonr(dx[both], dy[both])[0]
        corr_df.loc[x, y] = corr
        corr_df.loc[y, x] = corr
    # set once for every column (also covers a single-column input)
    for col in columns:
        corr_df.loc[col, col] = 0.0
    return corr_df


def abf_reader(in_data):
    """Load every sweep of an .abf file (Broad Institute, Guoping Feng's lab).

    Args:
        in_data: path of the .abf file, handed to ``pyabf.ABF``.

    Returns:
        list[dict]: one dict per sweep, in sweep order. Each dict holds numpy
        copies of ``sweepY`` ("voltage"), ``sweepC`` ("current") and
        ``sweepX`` ("time"), the sampling interval ``1 / dataRate`` ("dt")
        and the file's protocol name ("protocol_name").
    """
    recording = pyabf.ABF(in_data)

    def _load_sweep(sweep_number):
        # select the sweep first, then read its arrays
        recording.setSweep(sweep_number)
        return {
            "voltage": np.array(recording.sweepY),
            "current": np.array(recording.sweepC),
            "time": np.array(recording.sweepX),
            "dt": 1.0 / recording.dataRate,
            "protocol_name": recording.protocol,
        }

    return [_load_sweep(number) for number in range(recording.sweepCount)]


# Keys of the dictionary returned by compute_feature (one output column each).
_FEATURE_KEYS = (
    "holding",
    "all_burst_number",
    "spike_count",
    "runaway",
    "tonic_after_burst",
    "spike_width2",
    "peak_voltage",
    "spikes_per_burst",
    "burst_mean_freq",
    "inv_first_ISI",
    "AP2_AP1_peak_diff",
    "AHP_depth_abs",
    "time_to_first_spike",
    "postburst_min_values",
)

# eFEL features computed on the window starting at the end of the
# hyperpolarizing step.
_REBOUND_FEATURES = [
    "all_burst_number",
    "tonic_after_burst",
    "spike_count",
    "burst_runaway",
    "spike_width2",
    "peak_voltage",
    "spikes_per_burst",
    "burst_mean_freq",
    "inv_first_ISI",
    "AP2_AP1_peak_diff",
    "AHP_depth_abs",
    "time_to_first_spike",
    "postburst_min_values",
]

# Command-current value at or below which a sample belongs to the
# hyperpolarizing step (same units as the ABF current channel).
_HYPERPOL_THRESHOLD = -1.0


def _missing_features():
    """Return a fresh result dict with every feature set to None."""
    return dict.fromkeys(_FEATURE_KEYS)


def _reduce_or_zero(values, reducer):
    """Apply ``reducer`` to an eFEL value array; return 0 if it is None or empty."""
    if values is None or len(values) == 0:
        return 0
    return reducer(values)


def compute_feature(row, with_plot=True, t_max=15000):
    """Extract rebound-burst features from one .abf recording.

    The analysed sweep is the only sweep of the file, or, for multi-sweep
    files, the first sweep whose protocol is ``ReboundBurstProtocol``. Burst
    features are computed by eFEL between the end of the hyperpolarizing
    current step and ``t_max``. The holding voltage is eFEL's
    ``voltage_base`` between 0.1 and the onset of that step.

    Args:
        row (Mapping): must provide ``path``; when ``with_plot`` is true it
            must also provide ``month``, ``day``, ``cell``, ``cell_type`` and
            ``rec_id``.
        with_plot (bool): save a PDF of the voltage trace under
            ``traces/<cell_type>/<month>_<day>_<cell>/``.
        t_max (float): end of the feature window, in the units of the time
            axis (ms, assuming pyabf's ``sweepX`` is in seconds).

    Returns:
        dict: one entry per key of ``_FEATURE_KEYS``.
            - All values are None when the recording is unusable: a
              multi-sweep file without a rebound protocol, or no
              hyperpolarizing step.
            - Array-valued features are reduced to a scalar and reported as
              0 when eFEL returns nothing for them.
    """
    sweeps = abf_reader(row["path"])
    if len(sweeps) == 1:
        data = sweeps[0]
    else:
        rebound = [d for d in sweeps if d["protocol_name"] == "ReboundBurstProtocol"]
        if not rebound:
            # not a bursting trace, but an IC protocol for tonic firing
            return _missing_features()
        data = rebound[0]

    time = data["time"] * 1000
    voltage = data["voltage"]
    # a single mask for both window edges, so they always agree
    hyperpol = data["current"] <= _HYPERPOL_THRESHOLD

    # if there is no hyperpolarization, the trace is unusable
    if not np.any(hyperpol):
        return _missing_features()
    hyperpol_times = time[hyperpol]

    rebound_trace = {
        "T": time,
        "V": voltage,
        "stim_start": [hyperpol_times[-1]],
        "stim_end": [t_max],  # fixed end time to cut longer traces
    }
    # one eFEL call for all rebound features on the same trace
    feats = efel.get_feature_values([rebound_trace], _REBOUND_FEATURES)[0]

    all_burst_number = feats["all_burst_number"][0]
    tonic_after_burst = feats["tonic_after_burst"][0]
    spike_count = feats["spike_count"][0]
    burst_runaway = feats["burst_runaway"][0]

    holding_trace = {
        "T": time,
        "V": voltage,
        "stim_start": [0.1],
        "stim_end": [hyperpol_times[0]],
    }
    holding_voltage = efel.get_feature_values([holding_trace], ["voltage_base"])[0][
        "voltage_base"
    ][0]

    if with_plot:
        plt.figure(figsize=(15, 5))
        plt.plot(time, voltage, lw=0.8)
        plt.axhline(holding_voltage)
        plt.suptitle(
            f"""holding v: {holding_voltage}, burst number: {all_burst_number},
            runaway: {burst_runaway}, #spikes: {spike_count}, tonic: {tonic_after_burst}"""
        )
        plt.gca().set_xlim(0, t_max)
        folder = f"{row['month']}_{row['day']}_{row['cell']}"
        Path(f"traces/{row['cell_type']}/{folder}").mkdir(parents=True, exist_ok=True)
        plt.savefig(f"traces/{row['cell_type']}/{folder}/trace_{row['rec_id']}.pdf")
        plt.close()

    return {
        "holding": holding_voltage,
        "all_burst_number": all_burst_number,
        "spike_count": spike_count,
        "runaway": burst_runaway,
        "tonic_after_burst": tonic_after_burst,
        # mean spike width, capped at 2
        "spike_width2": _reduce_or_zero(feats["spike_width2"], lambda v: min(np.mean(v), 2)),
        "peak_voltage": _reduce_or_zero(feats["peak_voltage"], np.mean),
        "spikes_per_burst": _reduce_or_zero(feats["spikes_per_burst"], np.mean),
        "burst_mean_freq": _reduce_or_zero(feats["burst_mean_freq"], np.mean),
        "inv_first_ISI": _reduce_or_zero(feats["inv_first_ISI"], lambda v: v[0]),
        "AP2_AP1_peak_diff": _reduce_or_zero(feats["AP2_AP1_peak_diff"], lambda v: v[0]),
        "AHP_depth_abs": _reduce_or_zero(feats["AHP_depth_abs"], np.mean),
        "time_to_first_spike": _reduce_or_zero(feats["time_to_first_spike"], lambda v: v[0]),
        "postburst_min_values": _reduce_or_zero(feats["postburst_min_values"], np.mean),
    }


# Three-letter English month abbreviation (used as a folder name under the
# data root) -> zero-padded two-digit month number (used in .abf file names).
month_map = {
    abbreviation: f"{number:02d}"
    for number, abbreviation in enumerate(
        ["Jan", "Feb", "Mar", "Apr", "May", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"],
        start=1,
    )
}


def get_data_df(data_path="../../ephys_data", years=("2020", "2021")):
    """Get the list of traces to extract features from.

    Each cell of the spp, ecel and runaway metadata modules is looked up
    file id by file id. The lookup tries
    ``<data_path>/<month>/<day>/<year>_<MM>_<DD>_<id:04d>.abf`` for every
    candidate year, and every path that exists becomes one row. If a file
    exists for several years, each one is listed.

    Also creates the ``traces`` output folder as a side effect.

    Args:
        data_path (str or pathlib.Path): root folder of the ephys data.
        years (iterable of str): candidate recording years.

    Returns:
        pandas.DataFrame: columns month, day, cell, rec_id, path (str) and
        cell_type.
    """
    from Spp_files_per_cell import Spp_cells_metadata
    from Ecel_files_per_cell import Ecel_cells_metadata
    from Burst_runaway_files_per_cell import Burst_runaway_cells_metadata

    data_path = Path(data_path)
    Path("traces").mkdir(exist_ok=True)

    metadata_per_type = {
        "spp": Spp_cells_metadata,
        "ecel": Ecel_cells_metadata,
        "runaway": Burst_runaway_cells_metadata,
    }
    columns = ["month", "day", "cell", "rec_id", "path", "cell_type"]
    records = []
    for cell_type, type_data in metadata_per_type.items():
        for month, month_data in type_data.items():
            month_number = month_map[month]
            for day, day_data in month_data.items():
                # day is used as a path component here, so it is a string
                day_folder = data_path / month / day
                for cell, cell_data in day_data.items():
                    for rec_id in cell_data["files_ids"]:
                        for year in years:
                            path = (
                                day_folder
                                / f"{year}_{month_number}_{int(day):02d}_{rec_id:04d}.abf"
                            )
                            if path.exists():
                                records.append((month, day, cell, rec_id, str(path), cell_type))

    return pd.DataFrame(records, columns=columns)


def _burst_curve(x, a, b, c, d, e):
    """Skewed-Gaussian model of burst number versus holding voltage.

    ``a`` is the amplitude, ``b`` the width, ``c`` the center, ``d`` the skew
    and ``e`` a constant shift (the names of the columns stored per fit).
    """
    return a * np.exp(-1 / b * (x - c) ** 2) * (1 + erf((x - c) * d)) + e


def _mark_outliers(df, outliers):
    """Flag the rows of ``df`` listed in a manual outlier list.

    Args:
        df (pandas.DataFrame): recordings with month, day, cell and rec_id
            columns.
        outliers (list): ``[month, day, cell, rec_id]`` entries. ``rec_id`` is
            a string, and ``"-1"`` flags every recording of that cell.

    Returns:
        pandas.Series: boolean flags aligned with ``df.index``.
    """
    # ``day`` is a string when the table comes from get_data_df (it is a path
    # component there) and may be numeric after a CSV round-trip, so always
    # compare it as an int.
    days = df["day"].astype(int)
    flags = pd.Series(False, index=df.index)
    for month, day, cell, rec_id in outliers:
        mask = (df["month"] == month) & (days == int(day)) & (df["cell"] == cell)
        if rec_id != "-1":
            mask &= df["rec_id"] == int(rec_id)
        flags |= mask
    return flags


def _scatter_nonzero(df, x_col, y_col, xlabel, ylabel, filename):
    """Scatter ``y_col`` against ``x_col`` on a new figure and save it to ``filename``.

    Rows where either value is 0 (feature not available) are skipped.
    """
    plt.figure(figsize=(4, 4))
    dx, dy = df[x_col], df[y_col]
    both = (dx != 0) & (dy != 0)
    plt.scatter(dx[both], dy[both], s=2, c="k")
    plt.ylabel(ylabel)
    plt.xlabel(xlabel)
    plt.tight_layout()
    plt.savefig(filename)


def _swarm_box(data, y, filename, log_scale=False, **swarm_kws):
    """Swarm plot of ``y`` per cell type with an unfilled box plot on top, saved to ``filename``."""
    order = ["ecel", "spp", "runaway"]
    plt.figure(figsize=(3, 4))
    sns.swarmplot(data=data, x="cell_type", y=y, order=order, **swarm_kws)
    sns.boxplot(
        data=data,
        x="cell_type",
        y=y,
        order=order,
        showfliers=False,
        fill=False,
        color="k",
    )
    if log_scale:
        plt.yscale("log")
    plt.tight_layout()
    plt.savefig(filename)


if __name__ == "__main__":
    data_df = get_data_df()
    print(data_df)
    with Reuse("feature_data_df.csv") as reuse:
        data_df = reuse(
            evaluate,
            data_df,
            compute_feature,
            new_columns=[
                ["holding", 0],
                ["all_burst_number", 0],
                ["spike_count", 0],
                ["runaway", 0],
                ["tonic_after_burst", 0],
                ["spike_width2", 0],
                ["peak_voltage", 0],
                ["spikes_per_burst", 0],
                ["burst_mean_freq", 0],
                ["inv_first_ISI", 0],
                ["AP2_AP1_peak_diff", 0],
                ["AHP_depth_abs", 0],
                ["time_to_first_spike", 0],
                ["postburst_min_values", 0],
            ],
            parallel_factory="multiprocessing",
        ).drop(columns="exception")
    # clean data a bit more
    # remove None cells, for which features could not be extracted
    data_df = data_df.dropna()

    for (month, day, cell), data in data_df.groupby(["month", "day", "cell"]):
        # set runaway to inf for cells with very few finite values; these are
        # usually miscomputed
        if len(data[data["runaway"] < np.inf]) < 0.05 * len(data):
            data_df.loc[data.index, "runaway"] = np.inf

    # outliers (e.g. very long time to first burst, burst during holding, or weird otherwise)
    outliers_list = [
        # ecel
        ["Dec", 14, "Cell#2", "45"],
        ["Dec", 14, "Cell#3", "-1"],
        ["Feb", 1, "Cell#1", "-1"],
        ["Feb", 1, "Cell#2", "-1"],
        ["Mar", 16, "Cell#1", "-1"],
        ["Mar", 16, "Cell#4", "-1"],
        ["Mar", 16, "Cell#5", "201"],
        ["Dec", 14, "Cell#1", "-1"],
        ["Dec", 14, "Cell#5", "180"],
        ["Dec", 14, "Cell#5", "181"],
        # spp
        ["Jul", 11, "Cell#2", "64"],
        ["Jul", 11, "Cell#2", "79"],
        ["Jul", 11, "Cell#3", "116"],
        ["Jul", 11, "Cell#3", "121"],
        ["Jul", 9, "Cell#4", "197"],
        ["Jul", 9, "Cell#4", "198"],
        ["Jul", 9, "Cell#4", "167"],
        ["Sep", 25, "Cell#1", "10"],
        ["Jul", 9, "Cell#2", "102"],
        ["Jul", 9, "Cell#2", "103"],
        ["Jul", 9, "Cell#2", "104"],
        ["Jul", 9, "Cell#2", "105"],
        ["Jul", 9, "Cell#2", "106"],
        ["Jul", 9, "Cell#2", "107"],
        ["Jul", 9, "Cell#2", "108"],
        ["Jul", 9, "Cell#2", "109"],
        ["Jul", 9, "Cell#2", "110"],
        ["Jul", 9, "Cell#2", "111"],
        ["Jul", 9, "Cell#2", "112"],
        ["Jul", 9, "Cell#2", "113"],
        ["Jul", 9, "Cell#2", "114"],
        ["Jul", 9, "Cell#2", "115"],
        ["Jul", 9, "Cell#2", "116"],
        # runaway
        ["Jul", 9, "Cell#1", "2"],
        ["Jul", 9, "Cell#1", "5"],
        ["Jul", 9, "Cell#1", "37"],
        ["Jul", 9, "Cell#1", "38"],
        ["Jul", 9, "Cell#1", "39"],
        ["Jul", 9, "Cell#1", "40"],
        ["Jul", 9, "Cell#1", "41"],
        ["Jul", 9, "Cell#1", "45"],
        ["Jul", 11, "Cell#1", "16"],
        ["Jul", 11, "Cell#1", "17"],
        ["Sep", 23, "Cell#6", "-1"],
    ]

    data_df["outlier"] = _mark_outliers(data_df, outliers_list)
    data_df = data_df[~data_df["outlier"]]
    data_df.to_csv("feature_data_df_clean.csv")

    # per-cell maxima of all numeric columns, sorted by max burst number
    df_burst = data_df.groupby(["month", "day", "cell", "cell_type"]).max(numeric_only=True)
    print(df_burst.sort_values(by="all_burst_number"))

    # per-cell minima of all numeric columns, sorted by min runaway
    df_runaway = data_df.groupby(["month", "day", "cell", "cell_type"]).min(numeric_only=True)
    print(df_runaway.sort_values(by="runaway"))

    # (month, day, cell) whose single-cell burst curve is also saved on its own
    example_cells = {("Sep", 25, "Cell#5"), ("Jul", 9, "Cell#1"), ("Dec", 15, "Cell#2")}
    popt_columns = ["amp", "width", "center", "skew", "shift"]
    popt_rows = []
    for cell_type, _data_df in data_df.groupby("cell_type"):
        fig, ax = plt.subplots()
        with PdfPages(f"burst_features_{cell_type}.pdf") as pdf:
            for (month, day, cell), data in _data_df.groupby(["month", "day", "cell"]):

                plt.figure()
                plt.scatter(
                    data["holding"],
                    data["all_burst_number"],
                    c="k",
                )
                # plt.scatter(
                #    data["holding"][data["spike_count"] == 0],
                #    data["all_burst_number"][data["spike_count"] == 0],
                #    label="no spike",
                # )
                # plt.scatter(
                #    data["holding"][data["tonic_after_burst"] > 0],
                #    data["all_burst_number"][data["tonic_after_burst"] > 0],
                #    label="tonic",
                # )
                is_runaway = data["runaway"] < 0.1
                plt.scatter(
                    data["holding"][is_runaway],
                    data["all_burst_number"][is_runaway],
                    marker="+",
                    c="r",
                    label="runaway",
                )

                holding = data["holding"].to_numpy()
                order = np.argsort(holding)
                x = holding[order]
                y = data["all_burst_number"].to_numpy()[order]

                try:
                    popt, _ = curve_fit(
                        _burst_curve, x, y, bounds=([0, 20, -100, -2, 0], [100, 180, -40, 2, 1])
                    )
                except (RuntimeError, ValueError):
                    # RuntimeError: fit did not converge; ValueError: invalid data
                    print("fail")
                else:
                    popt_rows.append(dict(zip(popt_columns, popt), cell_type=cell_type))
                    _x = np.linspace(-80, -50, 100)
                    plt.plot(_x, _burst_curve(_x, *popt), c="k")
                    ax.plot(_x, _burst_curve(_x, *popt), c="k")

                plt.ylabel("burst number")
                plt.gca().set_xlim(-80, -45)
                plt.gca().set_ylim(0, 35)
                plt.legend()
                # plt.twinx()
                #
                # plt.ylabel("runaway")
                # data["runaway"] = np.clip(data["runaway"], 1e-3, 1e2)
                # plt.scatter(data["holding"], data["runaway"], marker="+", c="r", label="runaway")
                # plt.axhline(0.2, c="r", label="runaway threshold")
                #
                # plt.yscale("log")
                # plt.legend()
                # plt.xlabel("mV")
                plt.suptitle(f"{month} {day} {cell}")
                # int(day): the group key is a string when data_df comes from get_data_df
                if (month, int(day), cell) in example_cells:
                    plt.savefig(f"single_burst_curve_{cell_type}.pdf")

                pdf.savefig()
                plt.close()

            ax.set_ylabel("burst number")
            ax.set_xlabel("holding")
            ax.set_xlim(-80, -45)
            ax.set_ylim(0, 35)
            fig.savefig(f"burst_{cell_type}.pdf")
        plt.close(fig)
    popt_df = pd.DataFrame(popt_rows, columns=popt_columns + ["cell_type"])

    features = [
        "all_burst_number",
        "spikes_per_burst",
        "burst_mean_freq",
        "peak_voltage",
        "inv_first_ISI",
        "AP2_AP1_peak_diff",
        "AHP_depth_abs",
        "time_to_first_spike",
        "postburst_min_values",
    ]

    fig, axs = plt.subplots(1, len(features), figsize=(1.5 * len(features), 3))
    type_masks = [
        (data_df["cell_type"] == "ecel", "C0"),
        (data_df["cell_type"] == "spp", "C1"),
        (data_df["cell_type"] == "runaway", "C2"),
    ]
    for ax, feat in zip(axs, features):
        d = data_df[feat]
        d = d[d != 0]
        if feat.endswith("time_to_first_spike"):
            d = np.clip(d, -10, 800)
        value_range = (d.min(), d.max())
        for mask, color in type_masks:
            ax.hist(
                d[mask],
                bins=30,
                orientation="horizontal",
                color=color,
                histtype="step",
                range=value_range,
            )
        ax.axhline(d.mean(), c="k", label="exp. mean")
        ax.axhline(d.mean() - d.std(), c="k", ls="--")
        ax.axhline(d.mean() + d.std(), c="k", ls="--")
        ax.spines.right.set_visible(False)
        ax.spines.top.set_visible(False)
        ax.set_ylabel(feat.split(".")[-1])  # ALL_LABELS.get(feat, feat))
        ax.set_xlabel("# recordings")

    plt.tight_layout()
    plt.savefig("feature_distributions.pdf")
    data_df = data_df[data_df["time_to_first_spike"] > 0]
    data_df = data_df[data_df["time_to_first_spike"] < 500]
    corr_df = get_2d_correlations(data_df[features])
    plot_corr(corr_df)
    plt.savefig("2d_corr.pdf")

    # (y column, x column, y label, x label, output file)
    scatter_specs = [
        ("inv_first_ISI", "burst_mean_freq", "inv. first ISI", "burst mean freq", "corr_1.pdf"),
        ("spikes_per_burst", "burst_mean_freq", "spikes per burst", "burst mean freq", "corr_2.pdf"),
        (
            "postburst_min_values",
            "time_to_first_spike",
            "postburst min values",
            "time to first spike",
            "corr_3.pdf",
        ),
        ("inv_first_ISI", "AHP_depth_abs", "inv. first ISI", "AHP depth abs", "corr_4.pdf"),
        ("inv_first_ISI", "AP2_AP1_peak_diff", "inv. first ISI", "AP2-AP1 peak diff", "corr_5.pdf"),
    ]
    for y_col, x_col, ylabel, xlabel, filename in scatter_specs:
        _scatter_nonzero(data_df, x_col, y_col, xlabel, ylabel, filename)

    d = df_runaway.reset_index()
    d["runaway"] = np.clip(d["runaway"], 1e-2, 10)
    _swarm_box(d, "runaway", "strip_runaway.pdf", log_scale=True)

    _swarm_box(df_burst.reset_index(), "all_burst_number", "strip_burst_number.pdf", c="k")

    for tpe in popt_columns:
        # NOTE(review): "bust" looks like a typo of "burst"; kept so existing
        # output file names do not change
        _swarm_box(popt_df, tpe, f"bust_curve_{tpe}.pdf", c="k")
